In [ ]:
from __future__ import print_function, division

1. Visualisation with matplotlib


matplotlib is an extremely versatile graphics library for Python, originally written to reproduce MATLAB-style visualisations for engineering and scientific research. It is a low-level library and uses Python as the grammar for its graphics creation.

In this section of the course, we will learn about the Figure and Axes objects. These form the backbone of matplotlib, and all other objects are created on top of them. We will learn how to generate the simplest charts, the "bread and butter" of scientific visualisation: line plots and scatter plots.

Then we will cover distributional plots in one and two dimensions. Density-based plots are especially important when the number of data points is huge and overplotting becomes a problem.

Next we will learn how to create subplots and annotations, and how to customize your plots.

Throughout this course, we will also learn some basic principles of data visualisation and put them into practice.

1.1 Learning objectives

In this section we will learn:

  • How to import the matplotlib library and use the %matplotlib inline magic command.
  • How to set up Figure and Axes objects.
  • How to use the .plot method and its various keyword arguments.
  • How to add plot titles at the axes level.

1.2 Win, lose or DRAW!

By default, Jupyter Notebooks do not display matplotlib objects immediately after they are created. To activate automatic display, we have to instruct Jupyter to do so with an IPython magic command. In the cell below, run the command %matplotlib inline.


In [ ]:
%matplotlib inline

Now we need to import the matplotlib plotting library. We use the import and as keywords to do this. The keyword as allows us to replace the lengthy module name with an alias of our choosing. By convention, we use plt.


In [ ]:
import matplotlib.pyplot as plt

plt.plot([1,2,3], [4,7, -1], 'bo')

Above, you can see our very first plot, made with the plt.plot function: a simple scatter plot of the points $$(1,4), (2,7), (3, -1)$$ on the $x$-$y$ plane. Notice the arguments passed to plt.plot.

  • A list of the $x$-coordinates, [1,2,3]. In general purpose usage, this is a Series or an array object containing the predictor variables of our data.

  • A list of the $y$-coordinates, [4,7,-1]. This will usually be the response variable.

  • The final string bo is a convenience syntax for instructing plt.plot to create a scatterplot using round (o) markers coloured blue (b).

Note that if no marker string is passed to .plot, matplotlib will output a line plot by default.
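As a small sketch of this default (the variable names points and line are ours, not part of the course code), we can inspect the objects that .plot returns with and without a marker string:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# With the format string "bo": blue round markers and no connecting line.
(points,) = ax.plot([1, 2, 3], [4, 7, -1], "bo")

# Without a format string: a plain line and no markers.
(line,) = ax.plot([1, 2, 3], [4, 7, -1])

# Print the marker and line style that matplotlib stored for each call.
print(points.get_marker(), points.get_linestyle())
print(line.get_marker(), line.get_linestyle())
```

Each call to .plot returns a list of line objects, which is why the results are unpacked with a trailing comma.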

All this is very nice, but before we go into more details about plt.plot, we must set up the general workflow for working with this library.

1.3 The canvas: figure and axes

The parent object of any matplotlib visualisation is the Figure class. A Figure contains Axes objects, which in turn contain all the other components that make up a plot: lines, patches, dots, text, arrows, contours and the like. Even objects like legends, colorbars, axis labels, ticks and plot titles all belong to either a figure object or an axes object. Note that all visualisations produced by matplotlib are always contained in a Figure and an Axes. If these are not explicitly instantiated, then they are implicitly created when plt.plot is called.
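The containment described above can be checked directly. The following is a minimal sketch with made-up data; it only uses the figure, axes and line objects returned by plt.subplots and .plot:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
(line,) = ax.plot([1, 2, 3], [4, 7, -1])

# The Figure keeps a list of the Axes it contains.
print(ax in fig.axes)

# The Axes keeps track of the lines drawn on it.
print(line in ax.get_lines())

# Each Axes also holds a reference to its parent Figure.
print(ax.figure is fig)
```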

One way to use the matplotlib API is through a MATLAB-style interface, which relies on function calls such as plt.plot and various other annotation functions.

However, in this course, we wish to emphasize the object oriented interface to the API. This makes it easier to customize matplotlib and even seaborn visualisations. Essentially plot customization is about tweaking the attributes of an object.
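To illustrate what "tweaking the attributes of an object" can look like, here is a short sketch using the object-oriented interface. The data and the title text are placeholders chosen for this example:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
(line,) = ax.plot([1, 2, 3], [4, 7, -1])

# Customization is a matter of calling setters on the objects we hold.
line.set_color("red")
line.set_linewidth(3)
ax.set_title("Customized through the Axes object")
ax.set_xlabel("x")
ax.set_ylabel("y")
```

Because we keep references to fig, ax and line, any of them can be adjusted later without redrawing the plot from scratch.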

To create a figure object:


In [ ]:
fig1 = plt.figure()

When run, we see that an instance of the Figure class has been created. However, we do not see anything more than that printed output. In fact, if we were using another Python environment such as Spyder, or running Python from the command prompt (for example with IPython), we would see a blank pop-up window. Since we are using Jupyter Notebooks, our particular graphical renderer (inline) does not display it.

From this point on, we can create the plots we desire using plt.plot and assign the output to a variable name. Copy the line in the cell above and paste it below.


In [ ]:
# This line will be pasted in by course attendee. 

plt.plot([1,2,3], [4,7,-1], '*r')

The reason for creating the figure instance and calling plt.plot in the same cell is so that the axes for this plot (call it ax1) are contained in fig1. By doing this we can save the figure to a file by running the following command

fig1.savefig("my_first_plot.jpeg", dpi=300) 

in possibly another cell. Since the figure contains the ax1 object, the saved jpeg file will contain the plot above. If the plot and the figure instance were created in different cells, your jpeg file would be empty when fig1.savefig is called!

You may want to rerender your plot to adjust some parameters. By putting the .savefig method in another cell, you create a jpeg file output only once after you are satisfied with what you see (in Jupyter Notebook).
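The empty-file pitfall comes down to which figure actually owns the axes. The sketch below imitates it in a plain script with two figures; the names fig_a, fig_b and the output file name are hypothetical, and the file is written to a temporary directory. Our understanding is that the inline backend closes each figure once its cell has finished, which is why a plot made in a later cell lands in a new figure.

```python
import os
import tempfile

import matplotlib.pyplot as plt

fig_a = plt.figure()                    # becomes the current figure
plt.plot([1, 2, 3], [4, 7, -1], "*r")   # drawn into the current figure, fig_a
fig_b = plt.figure()                    # a second figure that never receives a plot

# fig_a owns one Axes, fig_b owns none, so saving fig_b would give a blank image.
print(len(fig_a.axes), len(fig_b.axes))

# Save the figure that actually contains the plot.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "example_plot.png")
fig_a.savefig(path, dpi=100)
print(os.path.exists(path))
```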


In [ ]:
fig1.savefig("my_first_plot.jpeg", dpi=300)

Check your current working directory. You should see the jpeg file there now. Congratulations! You have completed your first visualisation project!

2. Basic plotting with .plot method


The basic workflow of creating a visualisation consists of instantiating a figure and an axes object, getting the data and plotting it, and finally saving your output to a file. In this section, we will learn about matplotlib's versatile .plot method to create basic scatter and line plots.

As mentioned, the .plot method creates scatter and line plots. You have seen how to create basic scatter plots, so let's work through a simple visualisation project using line plots. Along the way, we will learn about the various customization options and keyword arguments of the .plot method.

For our first task, we want to compare the stock price movements of Apple Inc. (AAPL), 3M Company (MMM) and Google (GOOGL) from the 9th of November 2016 onwards. Stock price data can be obtained from Google, and pandas has an extension API to fetch it. We download the data, clean it and calculate the percentage change from one time interval to the next.
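To make the percentage change concrete before applying it to real data, here is a tiny sketch using pandas' pct_change method on made-up closing prices (the numbers are invented for illustration only):

```python
import pandas as pd

# Made-up closing prices for three consecutive trading days.
prices = pd.Series([100.0, 110.0, 99.0])

# Each entry is (price today / price yesterday) - 1.
# The first entry has no previous value, so it is NaN.
changes = prices.pct_change()
print(changes)
```

A rise from 100 to 110 is a change of about 0.10 (10%), and a fall from 110 to 99 is a change of about -0.10.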


In [ ]:
import pandas as pd
import pandas_datareader.data as web # requires pandas-datareader package
from datetime import datetime


start = datetime(2016,11,9)
# end = datetime(2017, 7, 24)
tickers = ["AAPL", "MMM", "GOOGL"]
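# Assumption: the "google" data source may no longer be available in recent
# pandas-datareader releases; if this call fails, use the CSV backup plan below.
stock_panel = web.DataReader(tickers, "google", start)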

closing_pct_change = (stock_panel.loc["Close", :, :]
                                 .pct_change()) 

# closing_pct_change.to_csv("closing_pct_change.csv")
# backup plan: read the saved data back, restoring the datetime index
# closing_pct_change = pd.read_csv("closing_pct_change.csv", index_col=0, parse_dates=True)

2.1 Using plt.plot to plot lines

Before we get to plotting the percentage difference in stock prices, let's get a feel for how to plot lines in matplotlib. As with a scatter plot, to plot lines we:

  • Pass an array of x-coordinates and y-coordinates to the first and second arguments of plt.plot.
  • If we do not pass any other arguments, plt.plot thinks that you want to plot a line graph.

In [ ]:
# Use this space to do the following: Plot a simple line graph using plt.plot. Use the data below.

xcoord = [1,2,3,4]
ycoord = [5, -3, 9, -1]

# Answer:

In matplotlib, a line can also be created for a time series object by passing a Series object indexed by datetime objects. If you pass such a series as an argument to plt.plot, matplotlib is smart enough to know that you are trying to make a time series plot. In our data above, we want to plot stock price changes, which is a type of time series.

Let's give it a go first by writing

plt.plot(closing_pct_change["AAPL"], label="AAPL")
plt.legend(loc="best")

in the cell below. The code above instructs matplotlib to create a time series plot. Note that we pass only one data argument to plot, closing_pct_change["AAPL"]. The information for the x-axis is already contained in the index of that series.

Next, we pass the label keyword argument to plt.plot. The label is a way for us to "name" our line. It is needed if we want to draw a legend (which we will). When we call plt.legend, we create a legend on the axes. By passing "best" to the loc keyword argument, we specify that the legend is to be placed in the best location as determined by matplotlib.
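The same idea can be tried without the downloaded stock data. The sketch below builds a small, made-up Series with a datetime index (the name demo_series and the label "DEMO" are placeholders) and uses the axes-level equivalents of plt.plot and plt.legend:

```python
import pandas as pd
import matplotlib.pyplot as plt

# A made-up time series: five daily values starting on 9 November 2016.
dates = pd.date_range("2016-11-09", periods=5, freq="D")
demo_series = pd.Series([0.01, -0.02, 0.005, 0.0, 0.015], index=dates)

fig, ax = plt.subplots()

# Only the Series is passed; the x-values are taken from its datetime index.
(line,) = ax.plot(demo_series, label="DEMO")

# The legend picks up the label given to the line.
legend = ax.legend(loc="best")
print([text.get_text() for text in legend.get_texts()])
```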


In [ ]:
# Copy and paste the code above in the space below

In matplotlib, an axes can contain many such lines: we simply call the .plot method repeatedly on the same axes object. This has the effect of constructing multiple lines on the same plot and is useful for comparing various time series objects.

The code for accomplishing this is below:

fig, ax = plt.subplots()

ax.plot(time_series_1, label=name_1)
ax.plot(time_series_2, label=name_2)
...
ax.plot(time_series_n, label=name_n)

We must first create a figure and an axes object. Both can be created together using plt.subplots() (and tuple unpacking).

However, look at the code above. Is there any way in which you can improve it?

2.2 Plotting multiple lines using the for loop

Yes, you probably guessed that we can use a for loop to repeatedly draw different time series objects on the same axes. By passing a unique name for each time series to the label keyword argument we are able to call legend to display the labels for each line.

In the cell below, we apply all that we have said above to plot the percentage change in closing stock price for Apple, 3M and Google on the same plot. We then place a legend on the plot and use the .set_title method to give our plot a title.

The companies' ticker codes are stored in a variable named tickers. We will use the ticker codes to access each column in the data frame by name.

The code to do all this is written below. Simply execute the cell below to see the results.


In [ ]:
fig2, ax2 = plt.subplots(figsize=(12,6)) # figure size is given as (width, height) in inches

for ticker in tickers:
    ax2.plot(closing_pct_change[ticker], label=ticker)

ax2.legend(loc="best")
ax2.set_title("Stock prices percentage change")

Do you notice that the variations in AAPL and GOOGL are highly correlated?

2.3 matplotlib marker codes

matplotlib employs a simple string code to allow us to customize the marker and line styles of our plots. Refer to the matplotlib documentation of the plot function for the complete list.

2.3.1 Color codes

First of all, we can customize colors using single letter strings. For example 'b' would mean blue and 'r' means red.

Here is a table of the supported color abbreviations used in matplotlib.

Code  Color
'b'   blue
'g'   green
'r'   red
'c'   cyan
'm'   magenta
'y'   yellow
'k'   black
'w'   white

Besides this, you can pass your own customized colors using the color keyword argument of the plot function (or method). The code is color=x, where x can be a CSS color name (see here for the full list), an RGB (or RGBA) tuple, a hex string ('#008000') or a grayscale intensity (pass a float between 0 and 1 as a string, for example '0.75').
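One way to check how matplotlib interprets a color specification is the to_rgba function from matplotlib.colors, which converts any accepted specification to an RGBA tuple. The list of specifications below is just a sample chosen for this sketch:

```python
from matplotlib.colors import to_rgba

# A sample of accepted color specifications: single-letter code, CSS name,
# hex string, RGB tuple and a grayscale intensity given as a string.
specs = ["b", "green", "#008000", (0.0, 0.5, 0.0), "0.75"]

for spec in specs:
    print(repr(spec), "->", to_rgba(spec))
```

Any of these values can be passed to the color keyword argument of .plot.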

2.3.2 Line and marker styles

Then we can customize line and marker styles in the same way. It is best to see just a few in action in the plot below. The full list can be found on the documentation page.


In [1]:
import matplotlib.pyplot as plt
import numpy as np 
from matplotlib.style import use
%matplotlib inline
use('ggplot')
xx = np.linspace(0,10,50)

fig, ax = plt.subplots(figsize=(12,12))
for m, marker in enumerate(['-', '--', ':', '-.', 'o', '.', 'v','^', '1', '2','3','4','s', '*', 'H', 'D', '+']):
    ax.plot(xx, (m+1)*xx+10, marker, label=marker)
ax.set_title("Marker and Line Style Codes", fontsize=12)
plt.legend(loc="best")



2.3.3 Combining line styles with color abbreviations

The nice thing about plot is that it accepts the following convenience syntax as an argument: we can combine line style, marker style and color into a single string of length 2 or 3. Here are some examples.

  • '-or' means plot a solid line, with a circle marker at the data points in red.
  • 'bD' means plot a blue diamond marker scatter plot.

We use marker codes like this to tell plot whether to draw a line plot or a scatter plot. This is necessary because, from plot's point of view, there is no essential difference between the two; a line plot (or time series plot) is created from a scatter plot by interpolating between successive entries in the Series passed to plot.
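The format string is a shorthand for separate keyword arguments. As a sketch (with made-up data and variable names of our own), the following draws the same styling twice, once with '-or' and once with explicit linestyle, marker and color arguments, and then a marker-only plot with 'bD':

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba

x = [1, 2, 3]
y = [4, 7, -1]
fig, ax = plt.subplots()

# Short form: solid line, circle markers, red.
(short_form,) = ax.plot(x, y, "-or")

# Long form: the same styling spelled out with keyword arguments.
(long_form,) = ax.plot(x, y, linestyle="-", marker="o", color="r")

# Marker and color only: blue diamonds with no connecting line.
(scatter_like,) = ax.plot(x, y, "bD")

for artist in (short_form, long_form, scatter_like):
    print(artist.get_linestyle(), artist.get_marker(), to_rgba(artist.get_color()))
```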

Here below, we shall use this knowledge to plot a simple scatter plot to show the correlation between percentage change in stock prices of both Google (resp. 3M) and Apple.


In [ ]:
fig3, ax3 = plt.subplots(figsize=(6,6))

marker = ["go", "rD"]
for mark, ticker in zip(marker, tickers[1:]):
    ax3.plot(closing_pct_change["AAPL"], closing_pct_change[ticker], mark, label=ticker, alpha=0.5)
    
ax3.set_xlabel("AAPL Percentage change")
ax3.set_title("Correlation of Percentage changes in closing stock price")
plt.legend(loc="best")

Here we see that Google's stock price movements correlate with Apple's stock price movements. This is to be expected as both are technology companies. Compare this with the correlation between 3M and Apple. As 3M is a manufacturing company, this correlation is lower.
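A visual impression of correlation can be backed up with a number. The sketch below uses synthetic data only (columns "A", "B" and "C" are invented: "B" is built from "A" plus noise, "C" is independent) to show how the corr method of a pandas DataFrame reports pairwise correlations. With the real data, the analogous call would presumably be closing_pct_change.corr().

```python
import numpy as np
import pandas as pd

# Synthetic data: B follows A closely, C is unrelated to A.
rng = np.random.default_rng(0)
base = rng.normal(size=250)

demo = pd.DataFrame({
    "A": base,
    "B": base + 0.5 * rng.normal(size=250),
    "C": rng.normal(size=250),
})

# Pairwise Pearson correlation coefficients between the columns.
corr = demo.corr()
print(corr.round(2))
```

A coefficient close to 1 corresponds to a tight diagonal cloud in the scatter plot, while a coefficient close to 0 corresponds to a shapeless one.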

Demo: Currency exchange by country comparison project


Below we will use our knowledge of matplotlib to create a simple time series plot that compares the changes in currency exchange rates for some OECD and non-OECD countries. I intend to use this demo to show the whole visualisation workflow, from getting the data to cleaning and visualising it. Our focus, however, is on the visualisation code. If you are unfamiliar with pandas, treat the data cleaning part of the code as a black box.


In [ ]:
# oecd_currency_xchange.to_csv("oecd_currency_xchange.csv")
import pandas as pd 
# infer_datetime_format is omitted here because it is deprecated in recent pandas versions
oecd_currency_xchange = pd.read_csv("oecd_currency_xchange.csv", index_col=0, parse_dates=True)

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

n_cols = oecd_currency_xchange.shape[1]
color = sns.color_palette("tab20", n_cols)
non_oecd = ["Brazil", "China (People's Republic of)", "Colombia", "Costa Rica", "India", "South Africa"]

fig4, ax4 = plt.subplots(figsize=(12,12))

for i in range(n_cols):
    plot_series = oecd_currency_xchange.iloc[:, i]
    # Dashed lines for non-OECD countries, solid lines for OECD countries
    if plot_series.name in non_oecd:
        linestyle = '--'
    else:
        linestyle = '-'
    ax4.plot(plot_series, linestyle, label=plot_series.name, color=color[i])

ax4.set_ylabel("Percent change from year 2000")
ax4.set_xlabel("Year")
ax4.set_title("Currency exchange rate per USD for some OECD and non-OECD countries\n2000 - present", fontsize=14)
plt.legend(loc="best");

In [ ]:
fig4.savefig("OECD.jpeg", dpi=300)